{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "# XFeat minimal inference example"
      ],
      "metadata": {
        "id": "2tDj94al5GAJ"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Clone repository"
      ],
      "metadata": {
        "id": "X8MPXBro5IFv"
      }
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "tVkH1ChzNcLW",
        "outputId": "da9a9474-76bd-4b66-8ecd-8ba0022f030e"
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Cloning into 'accelerated_features'...\n",
            "remote: Enumerating objects: 27, done.\u001b[K\n",
            "remote: Counting objects: 100% (11/11), done.\u001b[K\n",
            "remote: Compressing objects: 100% (10/10), done.\u001b[K\n",
            "remote: Total 27 (delta 0), reused 5 (delta 0), pack-reused 16\u001b[K\n",
            "Receiving objects: 100% (27/27), 13.29 MiB | 23.03 MiB/s, done.\n",
            "Resolving deltas: 100% (1/1), done.\n",
            "/content/accelerated_features\n"
          ]
        }
      ],
      "source": [
        "# Clone the repository only when it is not already present, so this cell can be\n",
        "# re-run on a live runtime without git failing on an existing destination path.\n",
        "# A subshell (parentheses) groups the check; braces are avoided because IPython\n",
        "# would try to expand them as Python expressions.\n",
        "!cd /content && ([ -d accelerated_features ] || git clone 'https://github.com/verlab/accelerated_features.git')\n",
        "%cd /content/accelerated_features"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Test on simple input (sparse setting)"
      ],
      "metadata": {
        "id": "32T-WzfU5NRH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import os\n",
        "import torch\n",
        "import tqdm\n",
        "\n",
        "from modules.xfeat import XFeat\n",
        "\n",
        "# Seed the global PyTorch RNG so the random inputs drawn in this notebook are\n",
        "# reproducible under Restart & Run All.\n",
        "torch.manual_seed(0)\n",
        "\n",
        "# Instantiate the model with its default constructor arguments\n",
        "xfeat = XFeat()\n",
        "\n",
        "# Random input tensor of shape (1, 3, 480, 640)\n",
        "x = torch.randn(1, 3, 480, 640)\n",
        "\n",
        "# Simple inference with batch = 1: keep the result of the only batch item\n",
        "output = xfeat.detectAndCompute(x, top_k=4096)[0]\n",
        "\n",
        "# Report the shape of each returned field\n",
        "separator = \"----------------\"\n",
        "print(separator)\n",
        "for name in (\"keypoints\", \"descriptors\", \"scores\"):\n",
        "    print(f\"{name}: \", output[name].shape)\n",
        "print(separator + \"\\n\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "o1TMnCEfNfvD",
        "outputId": "f59757f5-477a-4642-e955-7a5abefe3c21"
      },
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "loading weights from: /content/accelerated_features/modules/../weights/xfeat.pt\n",
            "----------------\n",
            "keypoints:  torch.Size([4096, 2])\n",
            "descriptors:  torch.Size([4096, 64])\n",
            "scores:  torch.Size([4096])\n",
            "----------------\n",
            "\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Stress test to check FPS on VGA (sparse setting)"
      ],
      "metadata": {
        "id": "8b9C09ya5UwL"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Stress test: run sparse extraction 100 times on one random input;\n",
        "# the tqdm progress bar reports the iteration rate.\n",
        "x = torch.randn(1, 3, 480, 640)\n",
        "progress = tqdm.tqdm(range(100), desc=\"Stress test on VGA resolution\")\n",
        "for i in progress:\n",
        "    output = xfeat.detectAndCompute(x, top_k=4096)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Zsjz-QT95ZrM",
        "outputId": "2df6f545-419f-4cc3-ad8b-bf5e12741dba"
      },
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "Stress test on VGA resolution: 100%|██████████| 100/100 [00:14<00:00,  6.74it/s]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Test with batched mode (sparse)"
      ],
      "metadata": {
        "id": "1jAl-ejS5du_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Batched mode: run sparse extraction on a batch of 4 random inputs\n",
        "x = torch.randn(4, 3, 480, 640)\n",
        "outputs = xfeat.detectAndCompute(x, top_k=4096)\n",
        "\n",
        "# One entry per batch item: how many keypoints were returned for it\n",
        "feature_counts = [len(item['keypoints']) for item in outputs]\n",
        "print(\"# detected features on each batch item:\", feature_counts)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lAarS8UH5gyg",
        "outputId": "883f13f8-3fac-48f2-c0a3-656a81b57f2c"
      },
      "execution_count": 4,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "# detected features on each batch item: [4096, 4096, 4096, 4096]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Match two images with built-in MNN matcher (sparse mode)"
      ],
      "metadata": {
        "id": "H60iMAlh5nqP"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Sparse matching between two random single-image inputs.\n",
        "# x1 is drawn before x2 (tuple elements are evaluated left to right).\n",
        "x1, x2 = torch.randn(1, 3, 480, 640), torch.randn(1, 3, 480, 640)\n",
        "mkpts_0, mkpts_1 = xfeat.match_xfeat(x1, x2)"
      ],
      "metadata": {
        "id": "6N-ZqoMZ5syf"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Match two images with semi-dense matching in batched mode (batch size = 4) for demonstration purposes"
      ],
      "metadata": {
        "id": "MOV4vZDp5v9_"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "# Build 4 random image pairs (batch size = 4)\n",
        "x1 = torch.randn(4, 3, 480, 640)\n",
        "x2 = torch.randn(4, 3, 480, 640)\n",
        "\n",
        "# Semi-dense matching over the whole batch; the result is indexed per image pair\n",
        "matches_list = xfeat.match_xfeat_star(x1, x2, top_k=5000)\n",
        "print('number of img pairs', len(matches_list))\n",
        "\n",
        "first_pair_matches = matches_list[0]\n",
        "print(first_pair_matches.shape)  # -> output is (x1,y1,x2,y2)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Axe0o6U85zGV",
        "outputId": "e1257959-24fc-4194-b2f1-ee06cf450b24"
      },
      "execution_count": 6,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "number of img pairs 4\n",
            "torch.Size([182, 4])\n"
          ]
        }
      ]
    }
  ]
}